Master's thesis case study 1¶
In [59]:
# Auto-reload edited project modules (adaptive_nof1) without restarting the kernel.
%load_ext autoreload
%autoreload 2
The autoreload extension is already loaded. To reload it, use: %reload_ext autoreload
In [60]:
import numpy
import pandas as pd
import seaborn
import torch
from matplotlib import pyplot as plt

from adaptive_nof1 import *
from adaptive_nof1.policies import *
from adaptive_nof1.helpers import *
from adaptive_nof1.inference import *
from adaptive_nof1.metrics import *
from adaptive_nof1.patient_explorer import show_patient_explorer
Simulation Study 1: Thompson vs UCB¶
In [61]:
# Initial parameters
block_length = 5  # consecutive observations per treatment block
length = 6 * block_length  # total observations per patient (6 blocks)
number_of_actions = 2  # NOTE(review): appears unused below; scenarios IV/V use 3 arms — verify
number_of_patients = 100  # independent simulated patients per configuration
In [62]:
# Scenarios
class NormalModel(Model):
    """Simulated patient whose outcome under each treatment arm is drawn from
    an independent normal distribution.

    Parameters:
        patient_id: seeds a per-patient RNG so every run is reproducible.
        mean: per-arm expected outcomes (length = number of arms).
        variance: per-arm outcome variances (same length as `mean`).
    """
    def __init__(self, patient_id, mean, variance):
        self.rng = numpy.random.default_rng(patient_id)  # reproducible per patient
        self.mean = mean
        self.variance = variance
        self.patient_id = patient_id
    def multivariate_normal_distribution(self, debug_data):
        # BUG FIX: `self` was missing from the signature, so an instance call
        # would have bound the instance to `debug_data` and the body's `self`
        # references would have failed. `debug_data` itself is unused here.
        # NOTE(review): the diagonal holds sqrt(variance), i.e. std devs,
        # although MultivariateNormal expects a covariance matrix — confirm
        # this is intentional (debug_data_to_torch_distribution does the same).
        cov = torch.diag_embed(torch.tensor(numpy.sqrt(self.variance)))
        return torch.distributions.MultivariateNormal(torch.tensor(self.mean), cov)
    def generate_context(self, history):
        # Context-free model: no covariates.
        return {}
    @property
    def additional_config(self):
        # Exposes the true arm means so metrics can compute regret/KL later.
        return {"expectations_of_interventions": self.mean}
    @property
    def number_of_interventions(self):
        return len(self.mean)
    def observe_outcome(self, action, context):
        # Sample one outcome for the chosen treatment arm.
        treatment_index = action["treatment"]
        return {"outcome": self.rng.normal(self.mean[treatment_index], numpy.sqrt(self.variance[treatment_index]))}
    def __str__(self):
        # Do not change this format: the rendered string is used as a key in
        # `model_mapping` further down the notebook.
        return f"NormalModel({self.mean, self.variance})"
# Scenario factories: each maps a patient id to a generating model.
# I: no effect difference; II/III: growing 2-arm effect; IV/V: 3 arms.
def generating_scenario_I(patient_id):
    return NormalModel(patient_id, mean=[0, 0], variance=[1, 1])

def generating_scenario_II(patient_id):
    return NormalModel(patient_id, mean=[1, 0], variance=[1, 1])

def generating_scenario_III(patient_id):
    return NormalModel(patient_id, mean=[2, 0], variance=[1, 1])

def generating_scenario_IV(patient_id):
    return NormalModel(patient_id, mean=[1, 0, 0], variance=[1, 1, 1])

def generating_scenario_V(patient_id):
    return NormalModel(patient_id, mean=[2, 1, 0], variance=[1, 1, 1])
In [63]:
# Inference Model
# Factory (fresh instance per policy): Normal likelihood with known variance 1
# and a Normal(0, 1) prior on each arm's mean.
def inference_model():
    return NormalKnownVariance(prior_mean=0, prior_variance=1, variance=1)
In [64]:
# Policies under comparison. Each internal policy is wrapped in a BlockPolicy
# so treatment decisions apply to blocks of `block_length` observations.
fixed_policy = BlockPolicy(
    block_length=block_length,
    internal_policy=FixedPolicy(
        inference_model=inference_model(),
        block_length=block_length,
        randomize=True,
    ),
)
SH_policy = BlockPolicy(
    block_length=block_length,
    internal_policy=SequentialHalving(
        inference_model=inference_model(),
        block_length=block_length,
        length=length,
    ),
)
etc_policy = BlockPolicy(
    block_length=block_length,
    internal_policy=ExploreThenCommit(
        inference_model=inference_model(),
        block_length=block_length,
        exploration_length=4,
        randomize=True,
    ),
)
thompson_sampling_policy = BlockPolicy(
    block_length=block_length,
    internal_policy=ThompsonSampling(inference_model=inference_model()),
)
ucb_policy = BlockPolicy(
    block_length=block_length,
    internal_policy=UpperConfidenceBound(
        inference_model=inference_model(),
        epsilon=0.05,  # confidence parameter of the UCB bound
    ),
)
In [65]:
# Full crossover study: every policy is evaluated on every generating scenario.
study_designs = {
    "n_patients": [number_of_patients],
    "policy": [
        SH_policy,
        fixed_policy,
        etc_policy,
        thompson_sampling_policy,
        ucb_policy,
    ],
    "model_from_patient_id": [
        generating_scenario_I,
        generating_scenario_II,
        generating_scenario_III,
        generating_scenario_IV,
        generating_scenario_V,
    ],
}
# One configuration per (policy, scenario) combination: 5 x 5 = 25 runs.
configurations = generate_configuration_cross_product(study_designs)
In [66]:
# Master switch: when False, previously simulated results are loaded from disk
# below instead of re-running the (expensive) simulations.
ENABLE_SIMULATION = False
print("Simulation was " + ("enabled" if ENABLE_SIMULATION else "disabled"))
Simulation was disabled
In [67]:
%timeit
if ENABLE_SIMULATION:
calculated_series, config_to_simulation_data = simulate_configurations(
configurations, length
)
In [68]:
# Persist fresh simulation results so later sessions can skip re-simulation.
if ENABLE_SIMULATION:
    write_to_disk("data/2024-02-11-mt_case_study_1_data.json", [calculated_series, config_to_simulation_data])
In [69]:
# Load the cached simulation results produced by a previous run of this notebook.
if not ENABLE_SIMULATION:
    calculated_series, config_to_simulation_data = load_from_disk("data/2024-02-11-mt_case_study_1_data.json")
In [70]:
def debug_data_to_torch_distribution(debug_data):
    """Build the posterior belief as a torch MultivariateNormal.

    `debug_data["variance"]` holds each arm's posterior variance of the mean;
    the known observation variance of 1 is added back in.

    NOTE(review): the diagonal is filled with sqrt(variance) (std devs) even
    though MultivariateNormal expects a covariance matrix — confirm intent
    (NormalModel.multivariate_normal_distribution uses the same convention).
    """
    posterior_means = debug_data["mean"]
    total_variance = numpy.array(debug_data["variance"]) + 1
    diagonal_matrix = torch.diag_embed(torch.tensor(numpy.sqrt(total_variance)))
    return torch.distributions.MultivariateNormal(torch.tensor(posterior_means), diagonal_matrix)
def data_to_true_distribution(data):
    """True outcome distribution of a simulated model: independent unit-variance
    normals centred at the model's per-arm expected outcomes."""
    true_means = data.additional_config["expectations_of_interventions"]
    identity_cov = torch.eye(len(true_means))
    return torch.distributions.MultivariateNormal(torch.tensor(true_means), identity_cov)
# Evaluation metrics computed for every simulated patient at every time step.
metrics = [
    SimpleRegretWithMean(),
    CumulativeRegret(),
    KLDivergence(data_to_true_distribution = data_to_true_distribution, debug_data_to_posterior_distribution=debug_data_to_torch_distribution)
]
# Map the verbose str() of models/policies to the short labels used in the thesis.
# The keys must match the respective __str__ outputs exactly.
model_mapping = {
    "NormalModel(([0, 0], [1, 1]))": "I",
    "NormalModel(([1, 0], [1, 1]))": "II",
    "NormalModel(([2, 0], [1, 1]))": "III",
    "NormalModel(([1, 0, 0], [1, 1, 1]))": "IV",
    "NormalModel(([2, 1, 0], [1, 1, 1]))": "V",
    "PooledNormalModel(([1, 0], [0.8, 0.8], [1, 1]))": "VI",
    "PooledNormalModel(([1, 0], [1.5, 1.5], [1, 1]))": "VII",
}
policy_mapping = {
    "BlockPolicy(FixedPolicy)": "Fixed",
    "BlockPolicy(ThompsonSampling(NormalKnownVariance(0, 1, 1)))": "TS",
    "BlockPolicy(UpperConfidenceBound(0.05 epsilon, NormalKnownVariance(0, 1, 1)))": "UCB",
    "BlockPolicy(ExploreThenCommit(4,NormalKnownVariance(0, 1, 1)))": "ETC",
    "BlockPolicy(SequentialHalving(30))": "SH",
    "BlockPolicy(ThompsonSampling(PooledNormalWithKnownVariance))": "Pooled TS",
    "BlockPolicy(UpperConfidenceBound(0.05 epsilon, PooledNormalWithKnownVariance))": "Pooled UCB",
}
# Long-format scores: one row per (model, policy, simulation, patient_id, t, metric).
df = SeriesOfSimulationsData.score_data(
    [s["result"] for s in calculated_series], metrics, {"model": lambda x: model_mapping[x], "policy": lambda x: policy_mapping[x]}
)
# Keep only the final time step for the summary table.
# (A dead bare `filtered_df` expression that displayed nothing mid-cell was removed.)
filtered_df = df.loc[df["t"] == df["t"].max()]
groupby_columns = ["model", "policy"]
# Wide format: one column per metric, one row per individual simulated patient.
pivoted_df = filtered_df.pivot(
    index=["model", "policy", "simulation", "patient_id"],
    columns="metric",
    values="score",
)
table = pivoted_df.groupby(groupby_columns).agg(['mean', 'std'])
policy_ordering = ["Fixed", "ETC", "SH", "UCB", "TS", "Pooled UCB", "Pooled TS"]
# Convert the 'policy' column in the MultiIndex to a Categorical type with the specified order
table = table.reset_index()
table['policy'] = pd.Categorical(table['policy'], categories=policy_ordering, ordered=True)
# Sort the DataFrame first by 'model' then by the now-ordered 'policy'
sorted_table = table.sort_values(by=['model', 'policy']).set_index(groupby_columns)
sorted_table
Out[70]:
| metric | Cumulative Regret (outcome) | KL Divergence | Simple Regret With Mean | ||||
|---|---|---|---|---|---|---|---|
| mean | std | mean | std | mean | std | ||
| model | policy | ||||||
| I | Fixed | -0.00138 | 5.327055 | 0.062780 | 0.068680 | 0.00 | 0.000000 |
| ETC | -0.00138 | 5.327055 | 0.057388 | 0.050710 | 0.00 | 0.000000 | |
| SH | -0.00138 | 5.327055 | 0.062780 | 0.068680 | 0.00 | 0.000000 | |
| UCB | -0.00138 | 5.327055 | 0.096489 | 0.119915 | 0.00 | 0.000000 | |
| TS | -0.00138 | 5.327055 | 0.086121 | 0.115799 | 0.00 | 0.000000 | |
| II | Fixed | -15.00138 | 5.327055 | 0.063840 | 0.064014 | 0.00 | 0.000000 |
| ETC | -20.00138 | 5.327055 | 0.061411 | 0.055619 | 0.00 | 0.000000 | |
| SH | -15.00138 | 5.327055 | 0.066572 | 0.065142 | 0.00 | 0.000000 | |
| UCB | -26.30138 | 6.842714 | 0.089736 | 0.126130 | 0.01 | 0.100000 | |
| TS | -22.95138 | 7.025298 | 0.071823 | 0.089907 | 0.02 | 0.140705 | |
| III | Fixed | -30.00138 | 5.327055 | 0.072629 | 0.077475 | 0.00 | 0.000000 |
| ETC | -40.00138 | 5.327055 | 0.077131 | 0.075092 | 0.00 | 0.000000 | |
| SH | -30.00138 | 5.327055 | 0.071373 | 0.062533 | 0.00 | 0.000000 | |
| UCB | -55.40138 | 7.316342 | 0.063888 | 0.070196 | 0.00 | 0.000000 | |
| TS | -51.20138 | 13.717214 | 0.137292 | 0.351123 | 0.04 | 0.281411 | |
| IV | Fixed | -10.00138 | 5.327055 | 0.160076 | 0.145236 | 0.06 | 0.238683 |
| ETC | -15.80138 | 6.891687 | 0.164202 | 0.135550 | 0.06 | 0.238683 | |
| SH | -12.15138 | 5.370252 | 0.167853 | 0.146736 | 0.06 | 0.238683 | |
| UCB | -23.65138 | 7.893791 | 0.148453 | 0.151855 | 0.04 | 0.196946 | |
| TS | -17.55138 | 8.710768 | 0.200423 | 0.195544 | 0.10 | 0.301511 | |
| V | Fixed | -30.00138 | 5.327055 | 0.184690 | 0.146695 | 0.04 | 0.196946 |
| ETC | -39.80138 | 7.436372 | 0.174190 | 0.139336 | 0.01 | 0.100000 | |
| SH | -36.75138 | 6.198618 | 0.164827 | 0.151292 | 0.05 | 0.219043 | |
| UCB | -48.35138 | 12.561813 | 0.935611 | 0.661852 | 0.25 | 0.435194 | |
| TS | -43.80138 | 14.777462 | 0.832389 | 0.719263 | 0.23 | 0.422953 | |
In [71]:
# Export the summary table as LaTeX for the thesis.
with open('mt_resources/5-optimization/table.tex', 'w') as file:
    # BUG FIX: the original assigned to `str`, shadowing the builtin for the
    # rest of the kernel session; renamed to `latex_table`.
    latex_table = sorted_table.style.format(precision=2).to_latex()
    print(latex_table)
    file.write(latex_table)
\begin{tabular}{llrrrrrr}
& metric & \multicolumn{2}{r}{Cumulative Regret (outcome)} & \multicolumn{2}{r}{KL Divergence} & \multicolumn{2}{r}{Simple Regret With Mean} \\
& & mean & std & mean & std & mean & std \\
model & policy & & & & & & \\
\multirow[c]{5}{*}{I} & Fixed & -0.00 & 5.33 & 0.06 & 0.07 & 0.00 & 0.00 \\
& ETC & -0.00 & 5.33 & 0.06 & 0.05 & 0.00 & 0.00 \\
& SH & -0.00 & 5.33 & 0.06 & 0.07 & 0.00 & 0.00 \\
& UCB & -0.00 & 5.33 & 0.10 & 0.12 & 0.00 & 0.00 \\
& TS & -0.00 & 5.33 & 0.09 & 0.12 & 0.00 & 0.00 \\
\multirow[c]{5}{*}{II} & Fixed & -15.00 & 5.33 & 0.06 & 0.06 & 0.00 & 0.00 \\
& ETC & -20.00 & 5.33 & 0.06 & 0.06 & 0.00 & 0.00 \\
& SH & -15.00 & 5.33 & 0.07 & 0.07 & 0.00 & 0.00 \\
& UCB & -26.30 & 6.84 & 0.09 & 0.13 & 0.01 & 0.10 \\
& TS & -22.95 & 7.03 & 0.07 & 0.09 & 0.02 & 0.14 \\
\multirow[c]{5}{*}{III} & Fixed & -30.00 & 5.33 & 0.07 & 0.08 & 0.00 & 0.00 \\
& ETC & -40.00 & 5.33 & 0.08 & 0.08 & 0.00 & 0.00 \\
& SH & -30.00 & 5.33 & 0.07 & 0.06 & 0.00 & 0.00 \\
& UCB & -55.40 & 7.32 & 0.06 & 0.07 & 0.00 & 0.00 \\
& TS & -51.20 & 13.72 & 0.14 & 0.35 & 0.04 & 0.28 \\
\multirow[c]{5}{*}{IV} & Fixed & -10.00 & 5.33 & 0.16 & 0.15 & 0.06 & 0.24 \\
& ETC & -15.80 & 6.89 & 0.16 & 0.14 & 0.06 & 0.24 \\
& SH & -12.15 & 5.37 & 0.17 & 0.15 & 0.06 & 0.24 \\
& UCB & -23.65 & 7.89 & 0.15 & 0.15 & 0.04 & 0.20 \\
& TS & -17.55 & 8.71 & 0.20 & 0.20 & 0.10 & 0.30 \\
\multirow[c]{5}{*}{V} & Fixed & -30.00 & 5.33 & 0.18 & 0.15 & 0.04 & 0.20 \\
& ETC & -39.80 & 7.44 & 0.17 & 0.14 & 0.01 & 0.10 \\
& SH & -36.75 & 6.20 & 0.16 & 0.15 & 0.05 & 0.22 \\
& UCB & -48.35 & 12.56 & 0.94 & 0.66 & 0.25 & 0.44 \\
& TS & -43.80 & 14.78 & 0.83 & 0.72 & 0.23 & 0.42 \\
\end{tabular}
In [72]:
def rename_df(df):
    """Add the short, ordered policy labels used by plot_lines for grouping.

    Writes the `policy_#_metric_#_model_p` column (an ordered Categorical over
    the first five policies); mutates and returns `df`.
    """
    short_labels = df["policy"].apply(lambda verbose: policy_mapping[verbose])
    df["policy_#_metric_#_model_p"] = pd.Categorical(
        short_labels, categories=policy_ordering[0:5], ordered=True
    )
    return df
# Cumulative regret over time, scenario V only (three arms, distinct means).
scenario_v_results = [
    s["result"]
    for s in calculated_series
    if s["configuration"]["model"] == "NormalModel(([2, 1, 0], [1, 1, 1]))"
]
ax = SeriesOfSimulationsData.plot_lines(
    scenario_v_results,
    [CumulativeRegret()],
    legend_position=(0.8, 1.0),
    process_df=rename_df,
)
plt.ylabel('Regret')
seaborn.move_legend(ax, "upper right", title=None)
plt.savefig("mt_resources/5-optimization/02_cumulative_regret.pdf", bbox_inches="tight")
In [73]:
# Simple regret over time, scenario V only.
scenario_v_results = [
    s["result"]
    for s in calculated_series
    if s["configuration"]["model"] == "NormalModel(([2, 1, 0], [1, 1, 1]))"
]
ax = SeriesOfSimulationsData.plot_lines(
    scenario_v_results,
    [SimpleRegretWithMean()],
    process_df=rename_df,
)
plt.ylabel('Simple Regret')
seaborn.move_legend(ax, "upper right", title=None)
plt.savefig("mt_resources/5-optimization/02_simple_regret.pdf", bbox_inches="tight")
In [74]:
# KL divergence between posterior belief and true distribution, scenario V only.
scenario_v_results = [
    s["result"]
    for s in calculated_series
    if s["configuration"]["model"] == "NormalModel(([2, 1, 0], [1, 1, 1]))"
]
ax = SeriesOfSimulationsData.plot_lines(
    scenario_v_results,
    [
        KLDivergence(
            data_to_true_distribution=data_to_true_distribution,
            debug_data_to_posterior_distribution=debug_data_to_torch_distribution,
        )
    ],
    process_df=rename_df,
)
plt.ylabel('KL Divergence')
seaborn.move_legend(ax, "upper right", title=None)
plt.savefig("mt_resources/5-optimization/02_kld.pdf", bbox_inches="tight")
In [75]:
plot_allocations_for_calculated_series(calculated_series)
/opt/homebrew/Caskroom/miniconda/base/envs/mt/lib/python3.11/site-packages/holoviews/plotting/bokeh/plot.py:987: UserWarning: found multiple competing values for 'toolbar.active_drag' property; using the latest value layout_plot = gridplot( /opt/homebrew/Caskroom/miniconda/base/envs/mt/lib/python3.11/site-packages/holoviews/plotting/bokeh/plot.py:987: UserWarning: found multiple competing values for 'toolbar.active_scroll' property; using the latest value layout_plot = gridplot(
Out[75]: